package com.mumux.framework.mvc.core;

import com.mumux.framework.mvc.annotation.Controller;
import com.mumux.framework.mvc.annotation.RequestMapping;
import com.mumux.framework.mvc.resolver.ViewResolver;
import org.dom4j.Document;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.net.URLDecoder;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Front controller of the mini MVC framework.
 *
 * <p>On {@link #init(ServletConfig)} it reads the XML file named by the
 * {@code contextConfigLocation} init parameter (resolved under {@code WEB-INF/classes}). It then
 * instantiates every {@link Controller}-annotated class found under each
 * {@code <component-scan base-package="...">} element and indexes their {@link RequestMapping}
 * methods. Finally it builds the {@link ViewResolver} declared by a {@code <bean>} element.
 *
 * <p>A request to {@code /<controller>/<method>} (relative to the context path) invokes the
 * matching handler method with no arguments. The handler's returned view name is passed to the
 * view resolver, and the request is forwarded to the resulting path. Unmapped paths get a 404.
 *
 * @author huangjuncong
 */
public class DispatcherServlet extends HttpServlet {

    private static final Logger LOG = Logger.getLogger(DispatcherServlet.class.getName());

    /** File-name suffix identifying compiled classes during package scanning. */
    private static final String CLASS_SUFFIX = ".class";

    /** Controller instances keyed by their normalized class-level mapping, e.g. {@code "user"}. */
    private final Map<String, Object> mvcBeanMap = new ConcurrentHashMap<>();

    /** Handler methods keyed by {@code "<controller>/<method>"}, e.g. {@code "user/list"}. */
    private final Map<String, Method> handlerMapping = new ConcurrentHashMap<>();

    /** View resolver from the last {@code <bean>} in the config that implements ViewResolver. */
    private ViewResolver myViewResolver;

    /**
     * Scans controllers, builds the handler mapping and loads the view resolver.
     *
     * @param config servlet configuration carrying the {@code contextConfigLocation} init parameter
     * @throws ServletException propagated from the superclass initialization
     */
    @Override
    public void init(ServletConfig config) throws ServletException {
        // GenericServlet's contract: overrides of init(ServletConfig) must call super.init(config);
        // otherwise getServletConfig()/getServletContext() stay null for this servlet.
        super.init(config);
        scanController(config);
        initHandlerMapping();
        loadViewResolver(config);
    }

    /**
     * Scans every {@code <component-scan base-package="...">} element of the MVC config file.
     * Each class annotated with {@link Controller} is instantiated and registered under its
     * class-level {@link RequestMapping} value. A class that cannot be loaded or instantiated
     * is logged and skipped, so one bad class does not abort the whole scan.
     *
     * @param config servlet config providing the {@code contextConfigLocation} init parameter
     */
    public void scanController(ServletConfig config) {
        Document document = readConfigDocument(config);
        if (document == null) {
            return;
        }
        for (Object node : document.getRootElement().elements()) {
            Element element = (Element) node;
            if ("component-scan".equals(element.getName())) {
                for (String className : getClassNames(element.attributeValue("base-package"))) {
                    registerController(className);
                }
            }
        }
    }

    /** Loads one scanned class and, if it is a {@code @Controller}, instantiates and registers it. */
    private void registerController(String className) {
        try {
            // initialize=false: do not run static initializers of classes that are not controllers.
            Class<?> clazz = Class.forName(className, false, DispatcherServlet.class.getClassLoader());
            if (!clazz.isAnnotationPresent(Controller.class)) {
                return;
            }
            RequestMapping mapping = clazz.getAnnotation(RequestMapping.class);
            if (mapping == null) {
                LOG.warning("@Controller " + className + " has no @RequestMapping; skipped");
                return;
            }
            String key = normalizePath(mapping.value());
            Object previous = mvcBeanMap.put(key, clazz.getDeclaredConstructor().newInstance());
            if (previous != null) {
                LOG.warning("Controller mapping '" + key + "' of " + className
                        + " replaces " + previous.getClass().getName());
            }
        } catch (ReflectiveOperationException | LinkageError e) {
            LOG.log(Level.WARNING, "Cannot register controller " + className, e);
        }
    }

    /**
     * Returns the fully-qualified names of all classes under a package. Only packages that live
     * in a directory on the class path are scanned; classes inside jar files are not found.
     *
     * @param scanPath dotted package name, e.g. {@code com.example.controller}; may be null
     * @return the class names found, possibly empty, never null
     */
    public List<String> getClassNames(String scanPath) {
        List<String> classNames = new ArrayList<>();
        if (scanPath == null || scanPath.trim().isEmpty()) {
            return classNames;
        }
        String packageName = scanPath.trim();
        String packageDirName = packageName.replace('.', '/');
        try {
            Enumeration<URL> dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
            while (dirs.hasMoreElements()) {
                URL url = dirs.nextElement();
                // Only exploded directories ("file:" URLs) are scanned; "jar:" URLs are skipped.
                if ("file".equals(url.getProtocol())) {
                    String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                    scanClassesInPackageByFile(packageName, filePath, true, classNames);
                }
            }
        } catch (IOException e) {
            LOG.log(Level.WARNING, "Cannot scan package " + packageName, e);
        }
        return classNames;
    }

    /**
     * Collects the names of all {@code .class} files under a package directory.
     *
     * @param scanPath    dotted package name corresponding to {@code packagePath}
     * @param packagePath file-system directory of that package
     * @param recursive   whether to descend into sub-directories (sub-packages)
     * @param classNames  output list the fully-qualified class names are appended to
     */
    public static void scanClassesInPackageByFile(String scanPath, String packagePath, final boolean recursive,
                                                  List<String> classNames) {
        File dir = new File(packagePath);
        if (!dir.isDirectory()) {
            return;
        }
        File[] entries = dir.listFiles(new FileFilter() {
            // Keep .class files, plus sub-directories when scanning recursively.
            @Override
            public boolean accept(File file) {
                return (recursive && file.isDirectory()) || file.getName().endsWith(CLASS_SUFFIX);
            }
        });
        // listFiles returns null on an I/O error.
        if (entries == null) {
            return;
        }
        for (File file : entries) {
            if (file.isDirectory()) {
                scanClassesInPackageByFile(scanPath + "." + file.getName(), file.getAbsolutePath(), recursive,
                        classNames);
            } else {
                String fileName = file.getName();
                String simpleName = fileName.substring(0, fileName.length() - CLASS_SUFFIX.length());
                classNames.add(scanPath + '.' + simpleName);
            }
        }
    }

    /**
     * Indexes every {@link RequestMapping}-annotated public method of each registered controller.
     * Each method is stored under {@code "<controller>/<method>"}, so handlers with the same path
     * in different controllers do not overwrite each other.
     */
    public void initHandlerMapping() {
        for (Map.Entry<String, Object> entry : mvcBeanMap.entrySet()) {
            String controllerKey = entry.getKey();
            for (Method method : entry.getValue().getClass().getMethods()) {
                RequestMapping mapping = method.getAnnotation(RequestMapping.class);
                if (mapping != null) {
                    handlerMapping.put(controllerKey + "/" + normalizePath(mapping.value()), method);
                }
            }
        }
    }

    /**
     * Builds the view resolver from the {@code <bean class="...">} elements of the MVC config.
     * Child elements named {@code prefix} / {@code suffix} are applied through
     * {@code setPrefix(String)} / {@code setSuffix(String)}. If several beans qualify, the last
     * one wins. Beans that do not implement {@link ViewResolver} are skipped.
     *
     * @param config servlet config providing the {@code contextConfigLocation} init parameter
     */
    public void loadViewResolver(ServletConfig config) {
        Document document = readConfigDocument(config);
        if (document == null) {
            return;
        }
        for (Object node : document.getRootElement().elements()) {
            Element element = (Element) node;
            if ("bean".equals(element.getName())) {
                ViewResolver resolver = createViewResolver(element);
                if (resolver != null) {
                    myViewResolver = resolver;
                }
            }
        }
    }

    /** Instantiates and configures one {@code <bean>}; returns null (after logging) on failure. */
    private static ViewResolver createViewResolver(Element bean) {
        String className = bean.attributeValue("class");
        if (className == null) {
            LOG.warning("<bean> without a class attribute; skipped");
            return null;
        }
        try {
            Class<?> clazz = Class.forName(className);
            if (!ViewResolver.class.isAssignableFrom(clazz)) {
                LOG.warning("Bean " + className + " is not a ViewResolver; skipped");
                return null;
            }
            Object resolver = clazz.getDeclaredConstructor().newInstance();
            for (Object node : bean.elements()) {
                Element property = (Element) node;
                String name = property.attributeValue("name");
                String value = property.attributeValue("value");
                // Setters are looked up only when the property is present.
                if ("prefix".equals(name)) {
                    clazz.getMethod("setPrefix", String.class).invoke(resolver, value);
                } else if ("suffix".equals(name)) {
                    clazz.getMethod("setSuffix", String.class).invoke(resolver, value);
                }
            }
            return (ViewResolver) resolver;
        } catch (ReflectiveOperationException | LinkageError e) {
            LOG.log(Level.WARNING, "Cannot create view resolver " + className, e);
            return null;
        }
    }

    /**
     * Parses the MVC config file named by the {@code contextConfigLocation} init parameter,
     * resolved against the webapp's {@code WEB-INF/classes} directory.
     *
     * @return the parsed document, or {@code null} if it cannot be located or parsed (logged)
     */
    private static Document readConfigDocument(ServletConfig config) {
        String location = config.getInitParameter("contextConfigLocation");
        if (location == null) {
            LOG.severe("Missing init parameter 'contextConfigLocation'");
            return null;
        }
        // getRealPath may return null, e.g. when the webapp is served from an unexploded .war.
        String webRoot = config.getServletContext().getRealPath("");
        if (webRoot == null) {
            LOG.severe("Cannot resolve the real path of the web application; MVC config not loaded");
            return null;
        }
        File file = new File(webRoot + "/WEB-INF/classes/" + location);
        try {
            return new SAXReader().read(file);
        } catch (Exception e) {
            // Parsing boundary: SAXReader reports parse and I/O failures as checked exceptions.
            LOG.log(Level.SEVERE, "Cannot parse MVC configuration " + file, e);
            return null;
        }
    }

    /** Strips leading and trailing '/' from a mapping value, e.g. {@code "/user/" -> "user"}. */
    private static String normalizePath(String value) {
        int start = 0;
        int end = value.length();
        while (start < end && value.charAt(start) == '/') {
            start++;
        }
        while (end > start && value.charAt(end - 1) == '/') {
            end--;
        }
        return value.substring(start, end);
    }

    /** GET requests are dispatched exactly like POST requests. */
    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp)
            throws ServletException, IOException {
        doPost(req, resp);
    }

    /**
     * Dispatches {@code /<controller>/<method>} (relative to the context path) to the mapped
     * handler and forwards to the view resolved from the handler's returned view name.
     *
     * @throws ServletException if no view resolver is configured, the handler cannot be invoked,
     *                          throws, or does not return a String; or if the forward fails
     * @throws IOException      if sending the 404 or forwarding fails
     */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp)
            throws ServletException, IOException {
        // Strip the context path so "/<ctx>/user/list" resolves like "/user/list" at the root context.
        String uri = req.getRequestURI();
        String contextPath = req.getContextPath();
        String path = (contextPath != null && uri.startsWith(contextPath))
                ? uri.substring(contextPath.length())
                : uri;

        // Expected shape "/<controller>/<method>" -> ["", controller, method, ...]
        String[] segments = path.split("/");
        if (segments.length < 3) {
            resp.sendError(HttpServletResponse.SC_NOT_FOUND);
            return;
        }
        String controllerKey = segments[1];
        Object controller = mvcBeanMap.get(controllerKey);
        Method handler = handlerMapping.get(controllerKey + "/" + segments[2]);
        if (controller == null || handler == null) {
            resp.sendError(HttpServletResponse.SC_NOT_FOUND);
            return;
        }
        if (myViewResolver == null) {
            throw new ServletException("No ViewResolver configured; cannot render " + path);
        }

        Object viewName;
        try {
            viewName = handler.invoke(controller);
        } catch (InvocationTargetException e) {
            throw new ServletException("Handler " + handler + " threw an exception", e.getCause());
        } catch (IllegalAccessException | IllegalArgumentException e) {
            throw new ServletException("Cannot invoke handler " + handler, e);
        }
        if (!(viewName instanceof String)) {
            throw new ServletException("Handler " + handler + " must return a view name String, got " + viewName);
        }

        // The view resolver maps the logical view name to a physical path.
        String target = myViewResolver.jspMapping((String) viewName);
        req.getRequestDispatcher(target).forward(req, resp);
    }
}
